package pkg

import (
	"fmt"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"net"
	"time"
)

// Session stores per-client data under a session name, so that one client
// can hold several independent sessions (for example a user session and an
// admin session).
type Session interface {
	// Create stores data under sessionName for the current client, valid for ttl.
	Create(c *gin.Context, sessionName string, data string, ttl time.Duration) error
	// Get returns the data previously stored under sessionName for the current client.
	Get(c *gin.Context, sessionName string) (string, error)
	// Delete removes the data stored under sessionName for the current client.
	Delete(c *gin.Context, sessionName string) error
}

// SessionWithRedis keeps session data in Redis and identifies the client by
// a session-ID cookie. All named sessions of one client share that cookie.
type SessionWithRedis struct {
	// client is used to set, read and delete session values.
	client *Redis
	// prefix is prepended to every Redis key built by getSessionKey.
	prefix string
	// cookieName names the cookie carrying the session ID; cookieDomain is
	// its Domain attribute (when empty, derived from the request Host).
	cookieName, cookieDomain string
	// secure and httpOnly are passed through as the cookie's flags.
	secure, httpOnly bool
}

// NewSession builds a SessionWithRedis backed by client.
//
// prefix is prepended to every Redis key. cookieName names the cookie that
// carries the session ID. cookieDomain is used as that cookie's Domain; when
// empty, the host of the incoming request is used instead. secure and
// httpOnly set the corresponding cookie flags.
func NewSession(client *Redis, prefix string, cookieName string, cookieDomain string, secure bool, httpOnly bool) *SessionWithRedis {
	s := new(SessionWithRedis)
	s.client = client
	s.prefix = prefix
	s.cookieName = cookieName
	s.cookieDomain = cookieDomain
	s.secure, s.httpOnly = secure, httpOnly
	return s
}

// Create stores data in Redis as the session named sessionName for the
// current client, expiring after ttl.
//
// sessionName separates independent sessions of one client that share a
// single session-ID cookie, e.g. a user session and an admin session.
//
// The session ID is chosen in this order:
//  1. an ID already resolved by an earlier Create in the same request
//     (cached in the gin context under cookieName);
//  2. the ID carried by the request cookie, if it parses as a UUID;
//  3. a new random UUID, which is sent back as a cookie whose Max-Age is
//     ttl in whole seconds.
//
// Empty or malformed cookie values are never reused. gin reports an empty
// cookie (or one that fails URL-unescaping) as "" with no error, and reusing
// it would put every such client on the shared key prefix+sessionName.
// Arbitrary client-chosen strings would also flow straight into the Redis key.
//
// NOTE(review): reusing a valid client-supplied ID still permits session
// fixation if Create is called at login; consider always issuing a fresh ID
// when the client's privileges change.
func (s *SessionWithRedis) Create(c *gin.Context, sessionName string, data string, ttl time.Duration) error {
	var sessionID string
	// Checked assertion: a non-string value stored under the same context key
	// by other code must not cause a panic here.
	if cached, exists := c.Get(s.cookieName); exists {
		if id, ok := cached.(string); ok {
			sessionID = id
		}
	}
	if sessionID == "" {
		if cookieID, err := c.Cookie(s.cookieName); err == nil {
			// uuid.Parse rejects "" and any non-UUID value.
			if _, parseErr := uuid.Parse(cookieID); parseErr == nil {
				sessionID = cookieID
			}
		}
	}
	if sessionID == "" {
		sessionID = uuid.New().String()
		c.SetCookie(s.cookieName, sessionID, int(ttl.Seconds()), "/", s.getCookieDomain(c), s.secure, s.httpOnly)
	}
	// Cache the ID so later Create calls in this request reuse it.
	c.Set(s.cookieName, sessionID)

	sessionKey := s.getSessionKey(sessionID, sessionName)
	return s.client.Set(c.Request.Context(), sessionKey, data, ttl)
}

// Get returns the data stored under sessionName for the client identified by
// the session cookie.
//
// It returns an error if the cookie is missing or empty, or if reading the
// key from Redis fails. That presumably includes a missing or expired key,
// depending on the Redis wrapper (TODO confirm).
func (s *SessionWithRedis) Get(c *gin.Context, sessionName string) (string, error) {
	sessionID, err := c.Cookie(s.cookieName)
	if err != nil {
		return "", fmt.Errorf("failed to get cookie: %w", err)
	}
	// gin returns "" with no error for a present-but-empty cookie. Using it
	// would read prefix+sessionName, a key shared by every such client.
	if sessionID == "" {
		return "", fmt.Errorf("failed to get cookie: %q is empty", s.cookieName)
	}
	sessionKey := s.getSessionKey(sessionID, sessionName)
	result, err := s.client.Get(c.Request.Context(), sessionKey).Result()
	if err != nil {
		return "", fmt.Errorf("failed to get session from redis: %w", err)
	}
	return result, nil
}

// Delete removes the session stored under sessionName for the current client
// and then expires the session cookie.
//
// It returns an error if the cookie is missing or empty, or if deleting the
// key from Redis fails. The cookie is only cleared after a successful delete.
//
// NOTE(review): the cookie is shared by every sessionName of this client, so
// clearing it here also detaches any other named session. Their Redis keys
// remain until their TTL expires. Confirm this is intended.
func (s *SessionWithRedis) Delete(c *gin.Context, sessionName string) error {
	sessionID, err := c.Cookie(s.cookieName)
	if err != nil {
		return fmt.Errorf("failed to get cookie: %w", err)
	}
	// An empty ID would address prefix+sessionName, a key shared by every
	// client sending an empty cookie; never delete that on one client's behalf.
	if sessionID == "" {
		return fmt.Errorf("failed to get cookie: %q is empty", s.cookieName)
	}
	sessionKey := s.getSessionKey(sessionID, sessionName)
	if err := s.client.Del(c.Request.Context(), sessionKey); err != nil {
		return fmt.Errorf("failed to delete session from redis: %w", err)
	}
	// A negative MaxAge asks the client to drop the cookie immediately.
	c.SetCookie(s.cookieName, "", -1, "/", s.getCookieDomain(c), s.secure, s.httpOnly)
	return nil
}

// getSessionKey builds the Redis key for one named session of one client:
// the prefix, then the session ID, then the session name, concatenated as-is.
//
// NOTE(review): there is no separator between the parts, so different
// (sessionID, sessionName) pairs can map to the same key ("ab"+"c" vs
// "a"+"bc"). Adding one would change the key format and orphan existing
// sessions, so it is left unchanged here.
func (s *SessionWithRedis) getSessionKey(sessionID string, sessionName string) string {
	key := s.prefix + sessionID
	return key + sessionName
}

// getCookieDomain returns the Domain attribute for the session cookie.
//
// A configured cookieDomain always wins. Otherwise the request's Host is
// used with any port stripped. If Host cannot be split (for example because
// it has no port), it is returned unchanged.
//
// NOTE(review): an explicit Domain attribute makes the cookie valid for
// subdomains as well; omitting it would yield a host-only cookie.
func (s *SessionWithRedis) getCookieDomain(c *gin.Context) string {
	if domain := s.cookieDomain; domain != "" {
		return domain
	}
	rawHost := c.Request.Host
	if hostOnly, _, err := net.SplitHostPort(rawHost); err == nil {
		return hostOnly
	}
	return rawHost
}
